# -*- coding: utf-8 -*-
import tensorflow as tf
import tensorflow.contrib.layers
import numpy as np

def conv2d(inputs, filters, kernel_size=1, strides=1, pad='SAME', name='None'):
    """Plain 2D convolution (no bias, no activation).

    Args:
        inputs: 4-D NHWC input tensor.
        filters: number of output channels.
        kernel_size: side length of the square kernel.
        strides: spatial stride, applied to both height and width.
        pad: padding scheme, 'SAME' or 'VALID'.
        name: name scope for the created ops.

    Returns:
        The 4-D NHWC convolution output.
    """
    with tf.name_scope(name):
        # Xavier-initialized kernel: [kh, kw, in_channels, out_channels].
        kernel = tf.Variable(tf.contrib.layers.xavier_initializer(uniform=False)(
            [kernel_size, kernel_size, inputs.get_shape().as_list()[3], filters]), name='weights')
        # Bug fix: honor the `pad` argument — it was previously ignored and
        # padding was hard-coded to 'SAME'.
        conv = tf.nn.conv2d(inputs, kernel, [1, strides, strides, 1], padding=pad, data_format='NHWC')
        return conv

def conv_bn_relu(inputs, filters, kernel_size=1, strides=1, pad='SAME', name='conv_bn_relu'):
    """Convolution followed by batch norm and a ReLU activation.

    Args:
        inputs: 4-D NHWC input tensor.
        filters: number of output channels.
        kernel_size: side length of the square kernel.
        strides: spatial stride, applied to both height and width.
        pad: padding scheme, 'SAME' or 'VALID'.
        name: name scope for the created ops.

    Returns:
        The batch-normalized, ReLU-activated convolution output.
    """
    # Bug fix: `name` and `pad` were accepted but never used.
    with tf.name_scope(name):
        kernel = tf.Variable(tf.contrib.layers.xavier_initializer(uniform=False)(
            [kernel_size, kernel_size, inputs.get_shape().as_list()[3], filters]), name='weights')
        conv = tf.nn.conv2d(inputs, kernel, [1, strides, strides, 1], padding=pad, data_format='NHWC')
        # decay=0.9; ReLU is fused in via activation_fn.
        norm = tf.contrib.layers.batch_norm(conv, 0.9, epsilon=1e-5, activation_fn=tf.nn.relu)
        return norm

def conv_block(inputs, numOut, name='conv_block'):
    """Three-stage bottleneck: (BN+ReLU -> conv) x 3.

    The first two convs halve the channel count; the last restores `numOut`.

    NOTE(review): all three convs use kernel_size=1. In the standard
    hourglass bottleneck the middle conv is usually 3x3 — confirm whether
    kernel_size=1 on conv_2 is intentional. Behavior is preserved here.

    Args:
        inputs: 4-D NHWC input tensor.
        numOut: number of output channels.
        name: name scope wrapping the whole block.

    Returns:
        The output tensor of the third convolution.
    """
    # Bug fix: `name` was accepted but never used; wrap the block in it so
    # the three stages nest under one scope in the graph.
    with tf.name_scope(name):
        with tf.name_scope('norm_1'):
            norm_1 = tf.contrib.layers.batch_norm(inputs, 0.9, epsilon=1e-5, activation_fn=tf.nn.relu)
            conv_1 = conv2d(norm_1, int(numOut / 2), kernel_size=1, strides=1, pad='SAME', name='conv')
        with tf.name_scope('norm_2'):
            norm_2 = tf.contrib.layers.batch_norm(conv_1, 0.9, epsilon=1e-5, activation_fn=tf.nn.relu)
            conv_2 = conv2d(norm_2, int(numOut / 2), kernel_size=1, strides=1, pad='SAME', name='conv')
        with tf.name_scope('norm_3'):
            norm_3 = tf.contrib.layers.batch_norm(conv_2, 0.9, epsilon=1e-5, activation_fn=tf.nn.relu)
            conv_3 = conv2d(norm_3, int(numOut), kernel_size=1, strides=1, pad='SAME', name='conv')
        return conv_3

#Residual Model

def skip_layer(inputs, numOut):
    """Residual shortcut: identity when channels already match `numOut`,
    otherwise a 1x1 convolution that projects to `numOut` channels."""
    with tf.name_scope('skip_layer'):
        in_channels = inputs.get_shape().as_list()[3]
        if in_channels != numOut:
            # Channel counts differ — project with a 1x1 conv.
            return conv2d(inputs, numOut, kernel_size=1, strides=1, pad='SAME', name='conv')
        # Same channel count — pass the input straight through.
        return inputs

def residual(inputs, numOut, name='residual_block'):
    """Residual unit: conv_block(inputs) + skip_layer(inputs).

    Args:
        inputs: 4-D NHWC input tensor.
        numOut: number of output channels of the block.
        name: name scope wrapping the whole unit.

    Returns:
        Element-wise sum of the bottleneck branch and the shortcut branch.
    """
    # Bug fix: `name` was accepted but never used.
    with tf.name_scope(name):
        convb = conv_block(inputs, numOut)
        skip = skip_layer(inputs, numOut)
        return tf.add_n([convb, skip], name='res_block')

def hourglass(inputs, n, numOut, name='hourglass'):
    """Recursive hourglass module.

    Keeps a full-resolution residual branch, downsamples by 2 for the lower
    branch, recurses `n` more levels, then upsamples (nearest-neighbor) back
    to the upper branch's spatial size and sums the two branches.
    """
    with tf.name_scope(name):
        # Upper branch at the incoming resolution.
        upper = residual(inputs, numOut, name='up_1')
        # Lower branch: 2x2 max-pool halves the spatial dims.
        pooled = tf.nn.max_pool(inputs, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        lower = residual(pooled, numOut, name='low_1')
        # Recurse until the innermost level, which is a plain residual.
        if n > 0:
            inner = hourglass(lower, n - 1, numOut, name='low_2')
        else:
            inner = residual(lower, numOut, name='low_2')
        inner = residual(inner, numOut, name='low_3')
        # Upsample to match the upper branch, then merge by summation.
        upsampled = tf.image.resize_nearest_neighbor(inner, tf.shape(upper)[1:3], name='upsamping')
        return tf.add_n([upsampled, upper], name='out_hg')

#mnist 28*28
def create_variable(name, shape, initializer,
    dtype=tf.float32, trainable=True):
    """Create (or fetch, under a reusing scope) a variable via tf.get_variable."""
    return tf.get_variable(
        name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        trainable=trainable,
    )

def weight_variable(shape):
    """Weight variable drawn from a truncated normal (stddev 0.1)."""
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

def bias_variable(shape):
    """Bias variable initialized to the constant 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape))

def fc(inputs, n_out, use_bias=True):
    """Fully connected layer: inputs @ weight (+ bias).

    Args:
        inputs: 2-D tensor of shape [batch, n_in].
        n_out: number of output units.
        use_bias: whether to add a zero-initialized bias term.

    Returns:
        2-D tensor of shape [batch, n_out].
    """
    n_in = inputs.get_shape().as_list()[-1]
    # Bug fix: the original used tf.name_scope('fc'), which tf.get_variable
    # ignores — a second call to fc() would collide on the variable name
    # "weight". variable_scope with default_name uniquifies ('fc', 'fc_1', ...)
    # so the layer can be instantiated more than once.
    with tf.variable_scope(None, default_name='fc'):
        weight = create_variable("weight", shape=[n_in, n_out],
                                 initializer=tf.random_normal_initializer(stddev=0.01))
        if use_bias:
            bias = create_variable("bias", shape=[n_out, ],
                                   initializer=tf.zeros_initializer())
            return tf.nn.xw_plus_b(inputs, weight, bias)
        return tf.matmul(inputs, weight)


class Hourglass(object):
    """Small hourglass classifier for 28x28 single-channel images (MNIST).

    Exposes `logits`, `predictions` (softmax), and `loss` (mean softmax
    cross-entropy against one-hot `labels`).
    """

    def __init__(self, inputs, labels, num_classes=10, is_training=True, scope="HourglassNet"):
        self.inputs = inputs
        # Bug fix: honor the constructor argument; it was hard-coded to 10.
        self.num_classes = num_classes
        # NOTE(review): is_training is stored but not wired into batch_norm —
        # confirm whether inference-mode BN is intended.
        self.is_training = is_training
        with tf.variable_scope(scope):
            # Stem: 6x6/2 conv -> residual -> 2x2 pool. For 28x28 input this
            # yields a 7x7 feature map (28 -> 14 -> 7).
            net = conv_bn_relu(inputs, filters=64, kernel_size=6, strides=2)
            net = residual(net, numOut=128, name='r1')
            net = tf.contrib.layers.max_pool2d(net, [2, 2], [2, 2], padding='SAME')
            net = hourglass(net, n=2, numOut=64)
            # Flatten the 7x7x64 feature map for the dense head.
            net = tf.reshape(net, [-1, 7 * 7 * 64])
            W_fc1 = weight_variable([7 * 7 * 64, 1024])
            b_fc1 = bias_variable([1024])
            net = tf.nn.relu(tf.matmul(net, W_fc1) + b_fc1)
            self.logits = fc(net, self.num_classes)
            self.predictions = tf.nn.softmax(self.logits)
            # Bug fix: the original loss -mean(labels * log(softmax)) goes NaN
            # as soon as a predicted probability underflows to 0. Compute the
            # numerically stable cross-entropy directly from the logits.
            self.loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=self.logits))

if __name__ == '__main__':
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    # Placeholders: flat 784-pixel images and one-hot 10-class labels.
    x = tf.placeholder("float", [None, 784])
    y_ = tf.placeholder("float", [None, 10])
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    x_label = tf.reshape(y_, [-1, 10])

    HourglassNet = Hourglass(x_image, x_label)
    train_step = tf.train.AdamOptimizer(1e-4).minimize(HourglassNet.loss)
    # NOTE(review): keep_prob is fed below but no dropout layer in the model
    # consumes it — either add dropout or drop this placeholder.
    keep_prob = tf.placeholder("float")
    # Bug fix: tf.arg_max is deprecated; tf.argmax is the supported spelling.
    correct_prediction = tf.equal(tf.argmax(HourglassNet.predictions, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        for i in range(10000):
            batch = mnist.train.next_batch(50)
            if i % 10 == 0:
                # Evaluate accuracy on the current batch every 10 steps.
                train_accuracy = accuracy.eval(
                    feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
                # Typo fix in the progress message ("accurary" -> "accuracy").
                print("step %d, training accuracy %g" % (i, train_accuracy))
            train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

